#!/bin/bash

# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -euo pipefail

###############################################################################
# is_int VALUE
#   succeeds when VALUE is non-empty and the shell's numeric comparison
#   accepts it as an integer; fails otherwise (also when VALUE is missing)
###############################################################################
is_int() {
    local candidate="${1:-}"

    # an empty or missing argument is never an integer
    if [ -z "${candidate}" ]; then
        return 1
    fi

    # comparing the value with itself only succeeds when it parses as an
    # integer; the diagnostic printed for anything else is discarded
    [ "${candidate}" -eq "${candidate}" ] 2>/dev/null
}

###############################################################################
# have_command NAME
#   succeeds when the shell can resolve NAME to something runnable; prints
#   nothing on stdout
###############################################################################
have_command() {
    local candidate="${1:-}"

    # `command -v` prints the resolution on success; only the status matters
    command -v "${candidate}" 1>/dev/null
    return "$?"
}

###############################################################################
# returns true if the util-linux enhanced getopt command exists and is runnable
#
# `getopt --test` is the documented probe: the enhanced version prints nothing
# and exits with status 4, while a legacy getopt prints " --" and exits 0.
###############################################################################
have_enhanced_getopt() {
    have_command getopt || return 1

    # Capture only the exit status.  The probe's output is discarded so that a
    # legacy getopt's " --" can never end up inside the arithmetic test.
    local rc=0
    getopt --test >/dev/null 2>&1 || rc=$?

    (( rc == 4 ))
}

# default process count; overwritten by the value given to -n/--ngpus
ngpus=1

version_string='0.1'
usage_string="usage: ${0} --help|--ngpus=<n> [--] command [args]"

# Option table printed by --help.  It starts with an empty line, has an empty
# line after the --ngpus entry and carries no trailing newline.  The default
# shown for --ngpus is the value ngpus holds at this point.
printf -v help_string '\n%s\n\n%s\n%s\n%s' \
    "  -n, --ngpus=n             number of processes to execute (default ${ngpus})" \
    "  -v, --version             output version and exit" \
    "  -h, --help                output this help message and exit" \
    "      --usage               output brief usage message and exit"


# The option parsing in this script needs long options and safely quoted
# output, which only the util-linux (enhanced) getopt provides; bail out early
# with a dedicated exit status when it is not available.
have_enhanced_getopt || {
    echo "can not find required util-linux enhanced getopt command" >&2
    exit 4
}

# DEBUG=1 turns on command tracing.  The comparison is done on the string: an
# arithmetic test would evaluate a value such as "yes" as a variable name and
# abort the whole script under `set -u`.
if [[ "${DEBUG:-}" == "1" ]]; then
    set -x
fi

# the "+" in front of the short options tells getopt to stop option processing
# at the first non-option argument.  This is appropriate here since we expect
# the user to provide a command with its own args, and those args must be
# passed through untouched.
tmp_args=$(getopt --name "${0}" --options "+n:hv" \
                  --longoptions "ngpus:,help,usage,version" \
                  -- "$@") || { echo "${usage_string}" >&2; exit 1; }
# getopt hands back one shell-quoted string; eval turns it back into separate
# positional parameters without breaking arguments that contain whitespace.
eval set -- "${tmp_args}"

while true; do
    case "${1:-}" in
        -h|--help)
            echo "${usage_string}"
            echo "${help_string}"
            exit 0
            ;;
        --usage)
            echo "${usage_string}"
            exit 0
            ;;
        -v|--version)
            echo "${version_string}"
            exit 0
            ;;
        -n|--ngpus)
            ngpus="${2:-}"
            shift 2
            ;;
        --)                     # end of known args
            shift
            break
            ;;
        *)
            # diagnostics and the usage hint both belong on stderr, matching
            # the other error paths of this script
            echo "unknown option ${1:-}" >&2
            echo "${usage_string}" >&2
            exit 1
            ;;
    esac
done

# Reject anything that is not an integer >= 1.  `test` is used for the range
# check because it parses decimal numbers the same way is_int does, so a value
# such as "08" is not misread as an invalid octal literal.
if ! is_int "${ngpus}" || test "${ngpus}" -lt 1; then
    echo "--ngpus requires a positive integer argument" >&2
    echo "${usage_string}" >&2
    exit 1
fi

# any non-empty DEBUG value prints the parsed state before launching
if [[ "${DEBUG:-}" ]]; then
    echo "ngpus is ${ngpus} and rest of args are [${*}]"
fi

# Pick the launcher from the environment: torchrun when NVIDIA_PYTORCH_VERSION
# is set (presumably exported by NVIDIA PyTorch container images -- confirm),
# mpirun otherwise.  Both start ${ngpus} copies of the user's command.
if [[ "${NVIDIA_PYTORCH_VERSION:-}" ]]; then
    torchrun --standalone --no_python --nproc_per_node="${ngpus}" "${@}"
else
    # -np takes the process count; without it mpirun would consume the first
    # word of the user's command as the count
    mpirun --allow-run-as-root --bind-to none -np "${ngpus}" "${@}"
fi
